{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-05-27T04:54:35.458415Z",
     "start_time": "2018-05-27T04:54:35.299306Z"
    }
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from pathlib import Path\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import DataLoader\n",
    "from torch.autograd import Variable"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-05-27T04:54:35.605623Z",
     "start_time": "2018-05-27T04:54:35.459565Z"
    }
   },
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('../src/')\n",
    "\n",
    "from model import UNet\n",
    "from dataset import SSSDataset\n",
    "from loss import DiscriminativeLoss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-05-27T04:54:35.608184Z",
     "start_time": "2018-05-27T04:54:35.606670Z"
    }
   },
   "outputs": [],
   "source": [
    "# Passed to SSSDataset(n_sticks=...) and, repeated once per image in the\n",
    "# batch, as the third argument of the discriminative loss.\n",
    "# Presumably the number of stick instances per image - TODO confirm\n",
    "# against the dataset implementation.\n",
    "n_sticks = 8"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-05-27T04:54:36.974833Z",
     "start_time": "2018-05-27T04:54:35.609077Z"
    }
   },
   "outputs": [],
   "source": [
    "# Model\n",
    "# Build the project's UNet with its default constructor arguments and move\n",
    "# it to the GPU. The call to .cuda() means this notebook needs a CUDA device.\n",
    "model = UNet().cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-05-27T04:54:36.978309Z",
     "start_time": "2018-05-27T04:54:36.975999Z"
    }
   },
   "outputs": [],
   "source": [
    "# Training data: batches of 4, loaded in the main process (num_workers=0)\n",
    "# into pinned host memory.\n",
    "# NOTE(review): shuffle=False keeps the same sample order every epoch -\n",
    "# confirm this is intended for training.\n",
    "train_dataset = SSSDataset(n_sticks=n_sticks, train=True)\n",
    "train_dataloader = DataLoader(\n",
    "    dataset=train_dataset,\n",
    "    batch_size=4,\n",
    "    shuffle=False,\n",
    "    num_workers=0,\n",
    "    pin_memory=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-05-27T04:54:36.981864Z",
     "start_time": "2018-05-27T04:54:36.979448Z"
    }
   },
   "outputs": [],
   "source": [
    "# Loss functions (both moved to the GPU)\n",
    "# criterion_ce scores the semantic prediction; criterion_disc scores the\n",
    "# instance prediction using the project's DiscriminativeLoss.\n",
    "criterion_ce = nn.CrossEntropyLoss().cuda()\n",
    "criterion_disc = DiscriminativeLoss(\n",
    "    delta_var=0.5,\n",
    "    delta_dist=1.5,\n",
    "    norm=2,\n",
    "    usegpu=True,\n",
    ").cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-05-27T04:54:36.986508Z",
     "start_time": "2018-05-27T04:54:36.983106Z"
    }
   },
   "outputs": [],
   "source": [
    "# Optimizer: SGD with momentum and weight decay over all model parameters.\n",
    "parameters = model.parameters()\n",
    "optimizer = optim.SGD(\n",
    "    parameters,\n",
    "    lr=0.01,\n",
    "    momentum=0.9,\n",
    "    weight_decay=0.001,\n",
    ")\n",
    "\n",
    "# Scheduler: scales the learning rate by `factor` once the value passed to\n",
    "# scheduler.step() has stopped decreasing for `patience` steps.\n",
    "scheduler = optim.lr_scheduler.ReduceLROnPlateau(\n",
    "    optimizer,\n",
    "    mode='min',\n",
    "    factor=0.1,\n",
    "    patience=10,\n",
    "    verbose=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-05-27T04:54:38.403441Z",
     "start_time": "2018-05-27T04:54:36.987711Z"
    },
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Train\n",
    "# Jointly minimise the discriminative (instance) loss and the cross-entropy\n",
    "# (semantic) loss. The epoch-mean discriminative loss alone drives the LR\n",
    "# scheduler and decides which checkpoint is kept.\n",
    "model_dir = Path('../model')\n",
    "# torch.save does not create missing directories; without this the first\n",
    "# checkpoint write fails when ../model does not exist yet.\n",
    "model_dir.mkdir(parents=True, exist_ok=True)\n",
    "modelname = 'model.pth'\n",
    "\n",
    "best_loss = np.inf\n",
    "for epoch in range(300):\n",
    "    print(f'epoch : {epoch}')\n",
    "    disc_losses = []\n",
    "    ce_losses = []\n",
    "    for batched in train_dataloader:\n",
    "        images, sem_labels, ins_labels = batched\n",
    "        # Tensors are autograd-aware since PyTorch 0.4, so the deprecated\n",
    "        # Variable wrapper is not needed before moving them to the GPU.\n",
    "        images = images.cuda()\n",
    "        sem_labels = sem_labels.cuda()\n",
    "        ins_labels = ins_labels.cuda()\n",
    "        model.zero_grad()\n",
    "\n",
    "        sem_predict, ins_predict = model(images)\n",
    "\n",
    "        # Discriminative Loss: the third argument is n_sticks repeated once\n",
    "        # per image in the batch.\n",
    "        disc_loss = criterion_disc(ins_predict,\n",
    "                                   ins_labels,\n",
    "                                   [n_sticks] * len(images))\n",
    "        # .item() reads the scalar from both 0-dim and one-element tensors;\n",
    "        # indexing .numpy()[0] raises IndexError on a 0-dim loss.\n",
    "        disc_losses.append(disc_loss.item())\n",
    "\n",
    "        # Cross Entropy Loss: take the argmax over dim 1 of the semantic\n",
    "        # labels (presumably one-hot along that axis - TODO confirm) and\n",
    "        # flatten the prediction to (-1, 2) so every position is one\n",
    "        # two-class sample.\n",
    "        _, sem_labels_ce = sem_labels.max(1)\n",
    "        sem_logits = sem_predict.permute(0, 2, 3, 1).contiguous().view(-1, 2)\n",
    "        ce_loss = criterion_ce(sem_logits, sem_labels_ce.view(-1))\n",
    "        ce_losses.append(ce_loss.item())\n",
    "\n",
    "        # Back-propagate the sum of both losses in a single pass.\n",
    "        loss = disc_loss + ce_loss\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "    # Epoch summary: means over all batches of this epoch.\n",
    "    disc_loss = np.mean(disc_losses)\n",
    "    ce_loss = np.mean(ce_losses)\n",
    "    print(f'DiscriminativeLoss: {disc_loss:.4f}')\n",
    "    print(f'CrossEntropyLoss: {ce_loss:.4f}')\n",
    "    scheduler.step(disc_loss)\n",
    "\n",
    "    # Keep only the weights with the lowest epoch-mean discriminative loss;\n",
    "    # each improvement overwrites the same file.\n",
    "    if disc_loss < best_loss:\n",
    "        best_loss = disc_loss\n",
    "        print('Best Model!')\n",
    "        torch.save(model.state_dict(), model_dir.joinpath(modelname))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "python3.6",
   "language": "python",
   "name": "python3.6"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  },
  "toc": {
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": "block",
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
